package base.mapper;

import base.PageVO;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.ReflectUtil;
import constants.MapperConstant;
import lombok.SneakyThrows;

import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Reflection-based JDBC mapper providing paging, insert, delete, lookup and update by id for
 * entity type {@code T}.
 *
 * <p>The table name and column list come from {@code MapperX.getTableName} and
 * {@code MapperX.getFieldNames}. Result columns are mapped back onto entity fields by field
 * name. The primary key is the field named {@code id} (or the first field when none is named so)
 * and is addressed in SQL as column {@code id}.
 *
 * <p>All statements run on one {@link Connection}, supplied through
 * {@link #BaseMapper(Class, Connection)} or {@link #setConnection(Connection)}. Write methods
 * never commit. When the connection is not in auto-commit mode, a failed write rolls back the
 * current transaction. No synchronization is performed.
 *
 * @param <T> entity type; must declare a no-argument constructor
 */
public class BaseMapper<T> implements MapperX {

    /** Connection every statement runs on; {@code null} until injected. */
    private Connection conn;

    /** Entity class this mapper reads and writes. */
    protected Class<T> clazz;

    /**
     * All entity fields, made accessible at construction. Assumed to be in the same order as
     * {@link #fieldNames}, because {@link #save(Object)} binds INSERT parameters positionally.
     */
    protected Field[] fields;

    /** Table name resolved from the entity class. */
    protected String tableName;

    /** Column list spliced into the SELECT and INSERT statements. */
    protected String fieldNames;

    /**
     * Primary-key field: the field named {@code id}, otherwise the first field.
     * It is {@code null} only when the entity has no fields at all.
     */
    protected Field idField;

    /**
     * Creates a mapper without a connection. A connection must be set with
     * {@link #setConnection(Connection)} before any query method is called.
     *
     * @param clazz entity class to map
     */
    public BaseMapper(Class<T> clazz) {
        this.clazz = clazz;
        this.fields = ReflectUtil.getFields(clazz);
        Arrays.stream(this.fields).forEach(i -> i.setAccessible(true));
        this.tableName = MapperX.getTableName(clazz);
        this.fieldNames = MapperX.getFieldNames(clazz);
        this.idField = resolveIdField(this.fields);
    }

    /**
     * Creates a mapper bound to the given connection.
     *
     * @param clazz entity class to map
     * @param conn  connection all statements will run on
     */
    public BaseMapper(Class<T> clazz, Connection conn) {
        this(clazz);
        this.conn = conn;
    }

    /**
     * Sets (or replaces) the connection used by this mapper.
     *
     * @param conn connection all subsequent statements will run on
     */
    public void setConnection(Connection conn) {
        this.conn = conn;
    }

    /**
     * Loads one page of rows together with the table's total row count.
     *
     * @param pageNum  1-based page number
     * @param pageSize maximum number of rows on the page
     * @return page holding the rows, page number/size, total count and number of pages
     * @throws IllegalArgumentException if {@code pageNum < 1} or {@code pageSize < 1}
     * @throws IllegalStateException    if no connection is set or a database error occurs
     */
    @SneakyThrows
    public PageVO<T> page(int pageNum, int pageSize) {
        if (pageNum < 1 || pageSize < 1) {
            throw new IllegalArgumentException("分页参数不合法: pageNum=" + pageNum + ", pageSize=" + pageSize);
        }
        Connection connection = requireConnection();
        // Computed in long so a large page number cannot overflow into a negative offset.
        long offset = (long) (pageNum - 1) * pageSize;
        String sql = MapperConstant.getSelectPrefix(tableName, fieldNames) + " LIMIT " + offset + "," + pageSize;

        List<T> records = new ArrayList<>();
        // executeQuery(), not execute(sql): the String overload must not be called on a PreparedStatement.
        try (PreparedStatement prepareStatement = connection.prepareStatement(sql);
             ResultSet resultSet = prepareStatement.executeQuery()) {
            while (resultSet.next()) {
                records.add(readRow(resultSet));
            }
        } catch (SQLException e) {
            throw new IllegalStateException("分页查询失败", e);
        }

        int total;
        try (Statement count = connection.createStatement();
             ResultSet countResult = count.executeQuery(MapperConstant.getSelectCountSql(tableName))) {
            total = countResult.next() ? countResult.getInt(1) : 0;
        } catch (SQLException e) {
            throw new IllegalStateException("分页查询总数失败", e);
        }

        PageVO<T> pageVO = new PageVO<>();
        pageVO.setPageNum(pageNum);
        pageVO.setPageSize(pageSize);
        pageVO.setRecords(records);
        pageVO.setTotal(total);
        // Ceiling division; yields 0 pages for an empty table. Long arithmetic avoids overflow.
        pageVO.setPages((int) ((total + (long) pageSize - 1) / pageSize));
        return pageVO;
    }

    /**
     * Instantiates a new entity and fills it from the current row of {@code resultSet}.
     *
     * @throws IllegalStateException if the entity cannot be instantiated
     */
    private T readRow(ResultSet resultSet) throws SQLException {
        java.lang.reflect.Constructor<T> constructor = ReflectUtil.getConstructor(clazz);
        if (constructor == null) {
            throw new IllegalStateException("实体缺少无参构造器: " + clazz.getName());
        }
        T entity;
        try {
            entity = constructor.newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("无法实例化实体: " + clazz.getName(), e);
        }
        buildEntity(entity, resultSet);
        return entity;
    }

    /**
     * Copies each column whose label equals a field name from the current row into {@code entity}.
     * Fields were already made accessible in the constructor.
     */
    private void buildEntity(T entity, ResultSet resultSet) throws SQLException {
        for (Field field : fields) {
            ReflectUtil.setFieldValue(entity, field, resultSet.getObject(field.getName()));
        }
    }

    /**
     * Inserts {@code entity} as a new row, binding every field, including nulls, positionally.
     *
     * @param entity entity to insert
     * @throws IllegalArgumentException if {@code entity} is empty per {@code BeanUtil.isEmpty}
     * @throws IllegalStateException    if no connection is set or the insert fails
     */
    @SneakyThrows
    public void save(T entity) {
        if (BeanUtil.isEmpty(entity)) {
            throw new IllegalArgumentException("保存行数据，参数不能为空");
        }
        Connection connection = requireConnection();
        int paramLen = fields.length;
        String sql = MapperConstant.INSERT_PREFIX + tableName + "(" + fieldNames + ") VALUES (" + MapperX.getParams(paramLen) + ")";
        try (PreparedStatement prepareStatement = connection.prepareStatement(sql)) {
            for (int i = 0; i < paramLen; i++) {
                prepareStatement.setObject(i + 1, readField(fields[i], entity));
            }
            prepareStatement.executeUpdate();
        } catch (SQLException | RuntimeException e) {
            throw rollbackAndWrap(connection, "保存行数据失败", e);
        }
    }

    /**
     * Deletes the row whose {@code id} column equals {@code id}.
     *
     * @param id primary-key value
     * @throws IllegalArgumentException if {@code id} is null
     * @throws IllegalStateException    if no connection is set or the delete fails
     */
    @SneakyThrows
    public void removeById(Long id) {
        if (ObjUtil.isEmpty(id)) {
            throw new IllegalArgumentException("通过id删除行数据，参数不能为空");
        }
        Connection connection = requireConnection();
        String sql = MapperConstant.DELETE_PREFIX + tableName + " WHERE id = ?";
        try (PreparedStatement prepareStatement = connection.prepareStatement(sql)) {
            prepareStatement.setLong(1, id);
            prepareStatement.executeUpdate();
        } catch (SQLException | RuntimeException e) {
            throw rollbackAndWrap(connection, "通过id删除行数据失败", e);
        }
    }

    /**
     * Loads the row whose {@code id} column equals {@code id}.
     *
     * @param id primary-key value
     * @return the mapped entity, or {@code null} if no row matches
     * @throws IllegalArgumentException if {@code id} is null
     * @throws IllegalStateException    if more than one row matches, no connection is set,
     *                                  or the query fails
     */
    @SneakyThrows
    public T detail(Long id) {
        if (ObjUtil.isEmpty(id)) {
            throw new IllegalArgumentException("通过id查询行数据，参数不能为空");
        }
        Connection connection = requireConnection();
        String sql = MapperConstant.getSelectPrefix(tableName, fieldNames) + " WHERE id = ?";
        try (PreparedStatement prepareStatement = connection.prepareStatement(sql)) {
            prepareStatement.setLong(1, id);
            try (ResultSet resultSet = prepareStatement.executeQuery()) {
                if (!resultSet.next()) {
                    return null;
                }
                T entity = readRow(resultSet);
                // A second row means the id is not unique.
                if (resultSet.next()) {
                    throw new IllegalStateException("通过id查询行数据，结果集大于1");
                }
                return entity;
            }
        } catch (SQLException e) {
            throw new IllegalStateException("通过id查询行数据失败", e);
        }
    }

    /**
     * Updates the row identified by the entity's id, setting only the non-null, non-id fields.
     * If every such field is null there is nothing to set and no statement is executed.
     *
     * @param entity entity carrying the id and the new values
     * @throws IllegalArgumentException if {@code entity} or its id value is empty
     * @throws IllegalStateException    if the entity has no fields, no connection is set,
     *                                  or the update fails
     */
    @SneakyThrows
    public void updateById(T entity) {
        if (ObjUtil.isEmpty(entity)) {
            throw new IllegalArgumentException("通过id更新行数据，参数不能为空");
        }
        if (idField == null) {
            throw new IllegalStateException("实体没有字段，无法确定id: " + clazz.getName());
        }
        // Check the id VALUE, not the Field object (which is never empty).
        Object idValue = readField(idField, entity);
        if (ObjUtil.isEmpty(idValue)) {
            throw new IllegalArgumentException("通过id更新行数据，id不能为空");
        }

        List<String> assignments = new ArrayList<>();
        List<Object> values = new ArrayList<>();
        for (Field field : fields) {
            if (field == idField) {
                continue;
            }
            Object value = readField(field, entity);
            if (value == null) {
                continue;
            }
            assignments.add(field.getName() + " = ?");
            values.add(value);
        }
        if (assignments.isEmpty()) {
            return;
        }

        Connection connection = requireConnection();
        // Plain concatenation: String.format would misinterpret any '%' in the prefix.
        String sql = MapperConstant.getUpdatePrefix(tableName) + "SET " + String.join(",", assignments) + " WHERE id = ?";
        try (PreparedStatement prepareStatement = connection.prepareStatement(sql)) {
            int index = 1;
            for (Object value : values) {
                prepareStatement.setObject(index++, value);
            }
            // The id parameter comes last, matching the trailing "WHERE id = ?".
            prepareStatement.setObject(index, idValue);
            prepareStatement.executeUpdate();
        } catch (SQLException | RuntimeException e) {
            throw rollbackAndWrap(connection, "通过id更新行数据失败", e);
        }
    }

    /** Returns the injected connection, failing fast with a clear message when none was set. */
    private Connection requireConnection() {
        if (conn == null) {
            throw new IllegalStateException("数据库连接未设置，请通过构造器或 setConnection 注入");
        }
        return conn;
    }

    /** Reads {@code field} from {@code entity}, converting the checked reflection error to unchecked. */
    private static Object readField(Field field, Object entity) {
        try {
            return field.get(entity);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("无法读取字段: " + field.getName(), e);
        }
    }

    /**
     * Rolls back the current transaction (if any) and returns an unchecked exception to throw.
     * Runtime exceptions are returned as-is; checked ones are wrapped with {@code message}.
     */
    private static RuntimeException rollbackAndWrap(Connection connection, String message, Exception cause) {
        rollbackQuietly(connection, cause);
        if (cause instanceof RuntimeException) {
            return (RuntimeException) cause;
        }
        return new IllegalStateException(message, cause);
    }

    /**
     * Rolls back when the connection is in manual-commit mode. JDBC forbids rollback() in
     * auto-commit mode. A rollback failure is attached to {@code cause} as suppressed
     * instead of hiding it.
     */
    private static void rollbackQuietly(Connection connection, Exception cause) {
        try {
            if (!connection.getAutoCommit()) {
                connection.rollback();
            }
        } catch (SQLException rollbackError) {
            cause.addSuppressed(rollbackError);
        }
    }

    /** Picks the field named "id", falling back to the first field; {@code null} for no fields. */
    private static Field resolveIdField(Field[] candidates) {
        for (Field candidate : candidates) {
            if ("id".equals(candidate.getName())) {
                return candidate;
            }
        }
        return candidates.length > 0 ? candidates[0] : null;
    }

}
